{ "cells": [ { "cell_type": "markdown", "id": "b5e1d292", "metadata": {}, "source": [ "## Ay 119 Transformers Exercises\n", "\n", "Matthew Graham, 2026\n", "\n", "These exercises walk through several aspects of the transformer architecture to build intuition for how it works." ] },
{ "cell_type": "markdown", "id": "d427dea9", "metadata": {}, "source": [ "### Setup\n", "\n", "Note that Exercises 4-6 require PyTorch." ] },
{ "cell_type": "code", "execution_count": 3, "id": "c1f4b022", "metadata": {}, "outputs": [], "source": [ "import math\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "np.random.seed(7)\n", "\n", "try:\n", "    import torch\n", "    import torch.nn as nn\n", "    import torch.nn.functional as F\n", "    from torch.utils.data import Dataset, DataLoader, random_split\n", "    TORCH_AVAILABLE = True\n", "    torch.manual_seed(7)\n", "except Exception as e:\n", "    TORCH_AVAILABLE = False\n", "    print(\"PyTorch is not available. Exercises 4-6 require PyTorch.\")\n", "    print(\"Import error:\", repr(e))\n", "\n", "plt.rcParams[\"figure.figsize\"] = (7, 4)\n", "plt.rcParams[\"axes.grid\"] = True\n" ] },
{ "cell_type": "markdown", "id": "815a626c", "metadata": {}, "source": [ "### Transformer help\n", "\n", "One possible way to write a transformer-based regressor or classifier is the following:" ] },
{ "cell_type": "code", "execution_count": null, "id": "40e48d7b", "metadata": {}, "outputs": [], "source": [ "class TinyTransformer(nn.Module):\n", "    \"\"\"\n", "    Small transformer model for sequence-level classification or regression.\n", "\n", "    Input shape:\n", "        x: (batch_size, sequence_length, input_dim)\n", "\n", "    Output:\n", "        classification: logits of shape (batch_size, output_dim)\n", "        regression: predictions of shape (batch_size, output_dim)\n", "    \"\"\"\n", "\n", "    def __init__(\n", "        self,\n", "        input_dim=3,\n", "        output_dim=1,\n", "        task=\"regression\",  # \"regression\" or \"classification\"\n", "        d_model=48,\n", "        nhead=4,\n", "        num_layers=2,\n", "        dim_feedforward=96,\n", "        dropout=0.1,\n", "    ):\n", "        super().__init__()\n", "\n", "        assert task in [\"regression\", \"classification\"]\n", "\n", "        self.task = task\n", "\n", "        # Project input features, e.g. (time, flux, flux_err), into transformer dimension\n", "        self.input_proj = nn.Linear(input_dim, d_model)\n", "\n", "        # Learnable summary token, analogous to [CLS] in BERT\n", "        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))\n", "\n", "        # Transformer encoder processes the whole sequence\n", "        enc_layer = nn.TransformerEncoderLayer(\n", "            d_model=d_model,\n", "            nhead=nhead,\n", "            dim_feedforward=dim_feedforward,\n", "            dropout=dropout,\n", "            batch_first=True,\n", "            activation=\"gelu\",\n", "        )\n", "\n", "        self.encoder = nn.TransformerEncoder(\n", "            enc_layer,\n", "            num_layers=num_layers,\n", "        )\n", "\n", "        # Final prediction head applied to the CLS representation\n", "        layers = [\n", "            nn.LayerNorm(d_model),\n", "            nn.Linear(d_model, output_dim),\n", "        ]\n", "\n", "        # Bounded regression: the sigmoid restricts predictions to (0, 1),\n", "        # so regression targets must be scaled to that range\n", "        if task == \"regression\":\n", "            layers.append(nn.Sigmoid())\n", "\n", "        self.head = nn.Sequential(*layers)\n", "\n", "    def forward(self, x):\n", "        B = x.shape[0]\n", "\n", "        # (B, L, input_dim) -> (B, L, d_model)\n", "        h = self.input_proj(x)\n", "\n", "        # Prepend one copy of the CLS token to every sequence in the batch\n", "        cls_token = self.cls.expand(B, -1, -1)\n", "        h = torch.cat([cls_token, h], dim=1)\n", "\n", "        z = self.encoder(h)\n", "\n", "        # Position 0 holds the encoded CLS token\n", "        cls_representation = z[:, 0]\n", "\n", "        return self.head(cls_representation)" ] },
{ "cell_type": "markdown", "id": "e284ae38", "metadata": {}, "source": [ "## 1. From Sequences to Attention\n", "\n", "Astronomical light curves often contain events separated by long time gaps: repeated flares, echoes, quasi-periodic structure, etc. A recurrent model must pass information step-by-step through the sequence. 
Self-attention instead allows every point to directly compare itself with every other point.\n", "\n", "We start with a toy sequence:\n", "\n", "$x = [1, 0, 0, 0, 1]$\n", "\n", "For this exercise, use identity projections, so $Q = K = V = x$ and each $q_i$ and $k_j$ is a scalar. The attention score between entries $i$ and $j$ is\n", "\n", "$s_{ij} = q_i k_j$\n", "\n", "Then apply a row-wise softmax to convert scores into attention weights.\n", "\n", "1. Explain why a model might need to relate the first and last entries.\n", "2. Compute the attention score matrix by hand.\n", "3. Compare your answer to the code below.\n", "4. Interpret which positions attend strongly to which other positions.\n" ] }, { "cell_type": "markdown", "id": "2dddf9fd", "metadata": {}, "source": [ "## 2. Implement Scaled Dot-Product Attention\n", "\n", "The core attention operation is\n", "\n", "$\n", "\\operatorname{Attention}(Q,K,V)\n", "= \\operatorname{softmax}\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V.\n", "$\n", "\n", "The factor $\\sqrt{d_k}$ prevents dot products from becoming too large when the key/query dimension $d_k$ is large.\n", "\n", "1. Fill in the function `scaled_dot_product_attention`.\n", "2. Compute the first row by hand for the provided matrices.\n", "3. Verify that your manual result matches the code output." ] }, { "cell_type": "markdown", "id": "9b629f83", "metadata": {}, "source": [ "For the third query vector $[1,1]$, why does the model spread attention across all three keys? What would happen if the scale factor $\\sqrt{d_k}$ were removed?" ] }, { "cell_type": "markdown", "id": "0a4dd3f1", "metadata": {}, "source": [ "## 3. Positional Encoding for Irregular Astronomical Time Sampling\n", "\n", "Self-attention on its own is insensitive to the order of its inputs, so transformers need explicit information about order or time. In language, tokens are often evenly spaced. 
In astronomy, observations are irregular: weather, visibility windows, moon phase, instrument schedules, and survey strategy all create gaps.\n", "\n", "A common sinusoidal positional encoding is\n", "\n", "$\n", "PE(t, 2i) = \\sin\\left(t / 10000^{2i/d}\\right), \\quad\n", "PE(t, 2i+1) = \\cos\\left(t / 10000^{2i/d}\\right),\n", "$\n", "\n", "where $d$ is the encoding dimension and $i = 0, \\dots, d/2 - 1$ indexes the sine/cosine pairs.\n", "\n", "Here we will apply it to continuous observing times rather than integer token positions.\n", "\n", "1. Generate uniform and irregular timestamps.\n", "2. Compute positional encodings for both.\n", "3. Explain how large gaps appear in encoding space.\n", "4. Consider why MJD should often be shifted or normalized before encoding.\n" ] }, { "cell_type": "markdown", "id": "cede0a92", "metadata": {}, "source": [ "Try replacing $t$ with $\\log(1 + t)$ or with time differences $\\Delta t_i = t_i - t_{i-1}$. Which representation might be better for a survey light curve with seasonal gaps?" ] }, { "cell_type": "markdown", "id": "cf4ef859", "metadata": {}, "source": [ "## 4. Transformer for Light Curve Classification\n", "\n", "We will now create a synthetic dataset with four classes to model AGN time series:\n", "\n", "0. Damped Random Walk-like stochastic variability\n", "1. Periodic sinusoidal variability\n", "2. Single flare\n", "3. Flare followed by a dip\n", "\n", "Each observation has three input features: time, flux, and flux error, $(t, f, \\sigma_f)$.\n", "\n", "This is intentionally simplified. Real survey data also involve masks, missing bands, calibration checks, outlier handling, and selection effects.\n", "\n", "1. Inspect the simulated light curves.\n", "2. Train the transformer classifier.\n", "3. Compare performance across classes.\n", "4. 
Identify which classes are easiest or hardest.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "7876c72d", "metadata": {}, "outputs": [], "source": [ "def simulate_ou(t, tau=150.0, sigma=0.08):\n", "    \"\"\"Simple OU/DRW-like process on irregular timestamps (t sorted in increasing order).\"\"\"\n", "    y = np.zeros_like(t, dtype=float)\n", "    # Draw the first point from the stationary distribution N(0, sigma^2)\n", "    y[0] = np.random.normal(0, sigma)\n", "    for i in range(1, len(t)):\n", "        dt = t[i] - t[i-1]\n", "        # Correlation with the previous point decays with the time gap\n", "        phi = np.exp(-dt / tau)\n", "        # Innovation variance that keeps the marginal variance at sigma^2\n", "        var = sigma**2 * (1 - phi**2)\n", "        y[i] = phi * y[i-1] + np.random.normal(0, np.sqrt(max(var, 1e-12)))\n", "    return y" ] }, { "cell_type": "markdown", "id": "9c42ba2a", "metadata": {}, "source": [ "Which class is most often confused with another class?" ] }, { "cell_type": "markdown", "id": "e6f46c05", "metadata": {}, "source": [ "## 5. Interpreting Attention in Astronomical Context\n", "\n", "The standard PyTorch `TransformerEncoderLayer` does not easily expose attention weights. For clarity, we will build a small one-layer model whose attention weights can be inspected directly.\n", "\n", "The goal is not to claim that attention weights are a perfect explanation. Instead, we want to test whether a model focuses on physically meaningful parts of a light curve: peaks, dips, periodic extrema, or seasonal gaps.\n", "\n", "1. Train or load a model that exposes attention weights, for example a one-layer (multi-head) attention classifier.\n", "2. Plot attention from the class token to each observation.\n", "3. Overlay those weights on the light curve.\n", "4. Discuss whether the attention map is astrophysically sensible.\n" ] }, { "cell_type": "markdown", "id": "c2b3e123", "metadata": {}, "source": [ "Does the model attend to physically interesting parts of the light curve? How would your interpretation change if the high-attention points were isolated outliers?\n" ] }, { "cell_type": "markdown", "id": "a9e95187", "metadata": {}, "source": [ "## 6. 
Transformer + Physical Model Hybrid\n", "\n", "Pure black-box prediction is flexible but can be hard to interpret. A hybrid model predicts parameters of a simple physical or phenomenological light curve model.\n", "\n", "Here the model predicts parameters of an exponential decay flare:\n", "\n", "$ F(t) = A \\exp\\left[-\\frac{t - t_0}{\\tau}\\right], \\quad t \\ge t_0 $\n", "\n", "For simplicity, we generate flare-only light curves and train a transformer to infer $(A, t_0, \\tau)$. This is a toy version of amortized inference: a neural network maps data directly to physical parameters.\n", "\n", "1. Generate flare light curves with known parameters.\n", "2. Train the parameter-prediction transformer.\n", "3. Compare predicted and true parameters.\n", "4. Discuss why this is more interpretable than direct class prediction.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6029f91d", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 5 }